import spacy
from tqdm import tqdm
import random
import openai
import os
import time
from openai import OpenAI
import json
import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed

from utils.util import read_json, write_json, read_txt
from utils.token_count_decorator import token_count_decorator

class Extract:
    """Extract an upper-cased verb lemma ("opcode") for every sentence of a corpus.

    ``sentences_arr`` (loaded from ``input_data_path``) is used as a list of sentence
    groups, each group a list of sentence strings. Each extraction back-end produces,
    per group, a list of ``{"sentence": ..., "opcode": ...}`` dicts holding only the
    sentences for which an opcode was found:

    * ``extract``   -- spaCy: lemma of the first token tagged ``VERB``; resumable.
    * ``extract_2`` -- NLTK: lemma of the first ``VB*``-tagged token other than
      is/are/was/were/am.
    * ``extract_3`` -- OpenAI Batch API driven by a prompt template; checkpointed
      after every chunk of ``self.concurrent`` groups.

    ``filter_extract_data`` and ``sample_sentences`` then prune that output.
    """

    def __init__(self, input_data_path, store_path):
        self.sentences_arr = read_json(input_data_path)
        self.store_path = store_path
        self.nlp = spacy.load("en_core_web_trf")
        # Number of sentence groups submitted per OpenAI batch in extract_3.
        self.concurrent = 500
        self.opcode_extraction_prompt = read_txt("dsl_design/data/prompt/opcode_extraction_3.txt")
        self.batch_input_path = "dsl_design/data/temp_batch/extract_batch_input.jsonl"
        self.batch_output_path = "dsl_design/data/temp_batch/extract_batch_output.jsonl"
        self.checkpoint_path = "dsl_design/data/cache/checkpoint.json"
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = []

        # Resume extract_3 from an earlier checkpoint, if one exists.
        if os.path.exists(self.checkpoint_path):
            with open(self.checkpoint_path, 'r', encoding="utf8") as cp_file:
                checkpoint_data = json.load(cp_file)
            self.start_index = self.__resume_index(checkpoint_data)
            self.prefix_data = checkpoint_data.get('prefix_data', [])
        else:
            self.start_index = 0
            self.prefix_data = []

    def __resume_index(self, checkpoint_data):
        """Return the first group index extract_3 has not processed yet.

        ``next_index`` is written by the current checkpoint format. Older checkpoints
        only stored ``last_index`` -- the *start* of the last completed chunk -- so
        resuming from it re-ran that chunk and duplicated its results; for those the
        completed chunk is skipped (assumes ``self.concurrent`` was the same then).
        """
        if "next_index" in checkpoint_data:
            return checkpoint_data["next_index"]
        if "last_index" in checkpoint_data:
            return checkpoint_data["last_index"] + self.concurrent
        return 0

    @staticmethod
    def __print_stats(sentences_len, extracted_len):
        """Print how many sentences were seen, how many got an opcode, and the failure rate."""
        failed = sentences_len - extracted_len
        print("Sentences: ", sentences_len)
        print("Extracted: ", extracted_len)
        print("Fail: ", failed)
        # Nothing processed (e.g. a fully resumed run) must not raise ZeroDivisionError.
        fail_pct = failed / sentences_len * 100 if sentences_len else 0.0
        print("Fail %: ", format(fail_pct, ".2f"), "%")

    def extract(self):
        """Extract opcodes with spaCy, saving progress to ``store_path`` every 100 groups.

        Each sentence is parsed with ``self.nlp``; the upper-cased lemma of its first
        ``VERB`` token is the opcode, and sentences without one are dropped. If
        ``store_path`` already exists, its groups are treated as done and processing
        resumes with the next group of ``sentences_arr``.
        """
        # Resume from previously saved partial results, if any.
        if os.path.exists(self.store_path):
            with open(self.store_path, 'r', encoding="utf8") as f:
                prefix_data = json.load(f)
        else:
            prefix_data = []
        processed_count = len(prefix_data)  # number of sentence groups already processed

        def process_sentence(sentence):
            sentence_fix = {"sentence": sentence, "opcode": ""}
            for token in self.nlp(sentence):
                if token.pos_ == "VERB":
                    sentence_fix["opcode"] = token.lemma_.upper()
                    break
            return sentence_fix

        def process_sentences(sentences):
            # Return the group size too, so all counting happens on the main thread.
            prefix = []
            for sentence in sentences:
                sentence_fix = process_sentence(sentence)
                if sentence_fix["opcode"] != "":
                    prefix.append(sentence_fix)
            return prefix, len(sentences)

        sentences_len = 0
        extracted_len = 0
        total_groups = len(self.sentences_arr)

        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(process_sentences, sentences)
                for sentences in self.sentences_arr[processed_count:]  # skip finished groups
            ]

            # Consume futures in submission order (not completion order) so that
            # prefix_data[i] always corresponds to sentences_arr[i]; resuming relies on it.
            for i, future in enumerate(tqdm(futures, total=len(futures))):
                prefix, group_len = future.result()
                prefix_data.append(prefix)
                sentences_len += group_len
                extracted_len += len(prefix)

                # Save every 100 groups, and once more after the final group.
                done = processed_count + i + 1
                if done % 100 == 0 or done == total_groups:
                    with open(self.store_path, 'w', encoding="utf8") as f:
                        json.dump(prefix_data, f, ensure_ascii=False, indent=4)

        self.__print_stats(sentences_len, extracted_len)

    def filter_extract_data(self, threshold=50):
        """Drop stop-word and rare opcodes from ``store_path``; write ``<base>_filtered.json``.

        An opcode survives when its lower-cased WordNet verb lemma is not listed in
        ``dsl_design/data/stop_word.txt``, that lemma is longer than two characters, and
        the opcode occurs more than ``threshold`` times in total. Every protocol is kept
        (possibly empty), so the output has one entry per input protocol.
        """
        known = set(self.stop_words)
        with open("dsl_design/data/stop_word.txt", "r", encoding="utf8") as file:
            for line in file:
                word = line.strip().lower()
                # Don't grow the list with duplicates when this method is called again.
                if word not in known:
                    known.add(word)
                    self.stop_words.append(word)

        data_extracted = read_json(self.store_path)
        opcode_count = Counter(sentence["opcode"] for protocol in data_extracted for sentence in protocol)
        print("Total opcodes", len(opcode_count))
        opcodes = list(opcode_count.keys())

        # Remove stop words.
        opcodes = [
            opcode for opcode in opcodes
            if (lemma_opcode := self.__lemmarize(opcode, pos="v")) not in known
            and len(lemma_opcode) > 2
        ]
        print("Delete stop words", len(opcodes))

        # Keep only opcodes that occur more than `threshold` times.
        opcodes = [opcode for opcode in opcodes if opcode_count[opcode] > threshold]
        print("Filter operations", len(opcodes))

        # Remove sentences whose opcode did not survive (set for O(1) lookups).
        kept_opcodes = set(opcodes)
        filtered_data = [
            [sentence for sentence in protocol if sentence["opcode"] in kept_opcodes]
            for protocol in data_extracted
        ]

        base_path = self.store_path.rsplit(".json", 1)[0]
        write_json(f"{base_path}_filtered.json", filtered_data)

    def sample_sentences(self, threshold=500):
        """Keep at most ``threshold`` distinct sentences per opcode; rewrite ``<base>_filtered.json``.

        Sentence texts are grouped per opcode across all protocols (duplicates collapse),
        up to ``threshold`` of them are sampled with ``random``, and each sampled text is
        kept only at its first occurrence. Protocols left empty are dropped. The file is
        read and overwritten in place.
        """
        base_path = self.store_path.rsplit(".json", 1)[0]
        filtered_path = f"{base_path}_filtered.json"
        data_extracted = read_json(filtered_path)

        # Distinct sentence texts per opcode across all protocols.
        opcode_sentences = defaultdict(set)
        for protocol in data_extracted:
            for sentence in protocol:
                opcode_sentences[sentence["opcode"]].add(sentence["sentence"])

        # Sample up to `threshold` sentences per opcode. random.sample() rejects sets on
        # Python 3.11+, and sorting first makes the draw reproducible under random.seed().
        remaining = {}
        for opcode, sentences in opcode_sentences.items():
            population = sorted(sentences)
            if len(population) > threshold:
                population = random.sample(population, threshold)
            remaining[opcode] = set(population)

        # Keep each sampled sentence once, at its first occurrence.
        filtered_protocols = []
        for protocol in data_extracted:
            filtered_protocol = []
            for sentence in protocol:
                pool = remaining[sentence["opcode"]]
                if sentence["sentence"] in pool:
                    filtered_protocol.append(sentence)
                    pool.remove(sentence["sentence"])

            # Add the protocol only if it still has sentences.
            if filtered_protocol:
                filtered_protocols.append(filtered_protocol)

        write_json(filtered_path, filtered_protocols)

    def extract_2(self):
        """Extract opcodes with NLTK and write them to ``store_path``.

        The opcode is the upper-cased WordNet verb lemma of the first token tagged
        VB/VBD/VBG/VBN/VBP/VBZ that is not is/are/was/were/am (case-insensitive).
        Sentences without such a token are printed and dropped.
        """
        verb_tags = {'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'}
        copulas = {"is", "are", "was", "were", "am"}
        prefix_data = []
        sentences_len = 0
        extracted_len = 0
        for sentences in tqdm(self.sentences_arr):
            prefix = []
            sentences_len += len(sentences)
            for sentence in sentences:
                sentence_fix = {"sentence": sentence, "opcode": ""}
                # Part-of-speech tagging.
                tagged = nltk.pos_tag(word_tokenize(sentence))
                for word, tag in tagged:
                    # Lower-case so a sentence-initial "Is"/"Are" is skipped too.
                    if tag in verb_tags and word.lower() not in copulas:
                        # __lemmarize lower-cases first: WordNet lookups are lower-case.
                        sentence_fix["opcode"] = self.__lemmarize(word, pos='v').upper()
                        break
                if sentence_fix["opcode"] != "":
                    prefix.append(sentence_fix)
                else:
                    print("No verb found in sentence")
                    print(sentence)
            prefix_data.append(prefix)
            extracted_len += len(prefix)
        self.__print_stats(sentences_len, extracted_len)
        write_json(self.store_path, prefix_data)

    def extract_3(self):
        """Extract opcodes with the OpenAI Batch API, checkpointing after every chunk.

        Groups are processed in chunks of ``self.concurrent`` starting at
        ``self.start_index``. Each sentence becomes one request built from
        ``self.opcode_extraction_prompt`` with ``---SENTENCES---`` replaced by the
        sentence. One result list per group is appended to ``self.prefix_data``, which
        is finally written to ``store_path``.
        """
        sentences_len = 0
        extracted_len = 0

        for k in tqdm(range(self.start_index, len(self.sentences_arr), self.concurrent)):
            sentences_raw = self.sentences_arr[k:k + self.concurrent]
            sentences = [sentence for sentence_raw in sentences_raw for sentence in sentence_raw]
            prefix_candidate = self.__prepare_prefix_candidate(sentences_raw)
            sentences_len += len(sentences)

            # A chunk made only of empty groups has nothing to submit; skip the API round-trip.
            if sentences:
                self.__empty_jsonl_contents()

                # custom_id is the sentence's position in the flattened chunk.
                for i, sentence in enumerate(sentences):
                    self.__gpt_batch_store(self.opcode_extraction_prompt.replace("---SENTENCES---", sentence), str(i))
                print("Batch stored")

                batch_obj = self.__gpt_batch_call()
                print("Batch called, waiting for results...")

                results = self.__get_batch_result(batch_obj.id)
                print("Results received")
                # NOTE(review): a failed/expired/cancelled batch yields [] here, so the whole
                # chunk is recorded as empty and checkpointed -- it is not retried on resume.

                self.__process_results(results, prefix_candidate)
                time.sleep(8)

            cleaned_prefix = self.__clean_prefix_candidate(prefix_candidate)
            self.prefix_data.extend(cleaned_prefix)
            extracted_len += sum(len(inner) for inner in cleaned_prefix)
            self.__save_checkpoint(k)

        self.__print_stats(sentences_len, extracted_len)
        write_json(self.store_path, self.prefix_data)

    def __prepare_prefix_candidate(self, sentences_raw):
        """Build one list per group of ``{"sentence", "opcode": "NONE"}`` placeholders."""
        return [
            [{"sentence": sentence, "opcode": "NONE"} for sentence in sentence_raw]
            for sentence_raw in sentences_raw
        ]

    def __process_results(self, results, prefix_candidate):
        """Write each result's ``text`` into ``prefix_candidate`` in place.

        ``custom_id`` is the sentence's index in the flattened chunk; a flat lookup
        table maps it back to its (group, position) slot.
        """
        slots = [
            (outer_index, inner_index)
            for outer_index, group in enumerate(prefix_candidate)
            for inner_index in range(len(group))
        ]
        for result in results:
            index = int(result["custom_id"])
            if not 0 <= index < len(slots):
                # An id outside this chunk cannot be placed; skip it rather than
                # clobbering an unrelated sentence's slot.
                continue
            outer_index, inner_index = slots[index]
            prefix_candidate[outer_index][inner_index]["opcode"] = result["text"]

    def __clean_prefix_candidate(self, prefix_candidate):
        """Drop entries whose opcode is still "NONE", keeping one (maybe empty) list per group."""
        return [
            [prefix for prefix in group if prefix["opcode"] != "NONE"]
            for group in prefix_candidate
        ]

    def __save_checkpoint(self, last_index):
        """Persist extract_3 progress after the chunk starting at ``last_index`` completed.

        ``next_index`` is where a later run resumes; ``last_index`` is kept for readers
        of the older format. The file is written to a temporary path and swapped in, so
        an interrupted write cannot leave a truncated checkpoint.
        """
        checkpoint_data = {
            "last_index": last_index,
            "next_index": last_index + self.concurrent,
            "prefix_data": self.prefix_data,
        }
        tmp_path = self.checkpoint_path + ".tmp"
        with open(tmp_path, 'w', encoding="utf8") as cp_file:
            json.dump(checkpoint_data, cp_file)
        os.replace(tmp_path, self.checkpoint_path)

    @token_count_decorator
    def __gpt_batch_store(self, content, index):
        """Append one chat-completions request for ``content`` (id ``index``) to the batch input file."""
        prompt_unit = {
            "custom_id": index,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini",
                "messages": [
                    {"role": "system", "content": "You are a natural language processing model designed for performing NLP tasks."},
                    {"role": "user", "content": content},
                ],
                "max_tokens": 1000,
            },
        }
        with open(self.batch_input_path, 'a', encoding="utf8") as file:
            # One JSON object per line (JSONL).
            file.write(json.dumps(prompt_unit) + '\n')

    def __gpt_batch_call(self):
        """Upload the batch input file and start a 24h chat-completions batch; return the batch object."""
        client = OpenAI()
        # Close the upload handle once the SDK has consumed it.
        with open(self.batch_input_path, "rb") as batch_file:
            batch_input_file = client.files.create(file=batch_file, purpose="batch")
        return client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

    def __get_batch_result(self, batch_id):
        """Poll batch ``batch_id`` every 3 s until it finishes; return ``[{"custom_id", "text"}]``.

        ``text`` is chosen by ``__pick_opcode``. Returns ``[]`` if the batch failed,
        expired, is cancelling/cancelled, or completed without an output file.
        """
        client = OpenAI()
        while True:
            batch = client.batches.retrieve(batch_id)
            if batch.status == "completed":
                if batch.output_file_id is None:
                    # No successful request, so there is no output file to download.
                    print("Batch completed without output")
                    return []
                result = client.files.content(batch.output_file_id).content
                with open(self.batch_output_path, 'wb') as file:
                    file.write(result)
                with open(self.batch_output_path, 'r', encoding="utf8") as file:
                    # Parse each non-blank JSONL line into a dict.
                    results = [json.loads(line) for line in file if line.strip()]
                return [
                    {"custom_id": r["custom_id"], "text": self.__pick_opcode(r)}
                    for r in results
                ]
            if batch.status in ("failed", "expired", "cancelled", "cancelling"):
                print(f"Batch {batch.status}")
                return []
            print(batch.status)
            time.sleep(3)

    def __pick_opcode(self, record):
        """Return the first reply line that ``__word_count`` counts as one word, else "NONE"."""
        response = record.get("response") or {}
        try:
            content = response["body"]["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Errored request: there is no chat completion to read.
            return "NONE"
        for line in (content or "").split("\n"):
            candidate = line.strip()
            if self.__word_count(candidate) == 1:
                return candidate
        return "NONE"

    def __empty_jsonl_contents(self):
        """Truncate the batch input/output JSONL files if they exist (the input file is appended to)."""
        for path in (self.batch_input_path, self.batch_output_path):
            if os.path.exists(path):
                with open(path, 'w'):
                    pass

    def __word_count(self, sentence):
        """Count the non-punctuation tokens spaCy finds in ``sentence``."""
        doc = self.nlp(sentence)
        return sum(1 for token in doc if not token.is_punct)

    def __lemmarize(self, word, pos="n"):
        """Return the WordNet lemma of ``word`` lower-cased, for part of speech ``pos``."""
        return self.lemmatizer.lemmatize(word.lower(), pos=pos)